from core.utils import Utils
from config import Config
import numpy as np
import os
import cv2
class ProposalLayer:
    """Turn RPN outputs into per-image proposals and evaluate them.

    Every method is stateless and is called on the class itself, e.g.
    ``ProposalLayer.generate_proposals(...)``.
    """

    @staticmethod
    def cal_evaluation_index(path_images, path_labels, proposals, image_info, iou_thresh=0.7):
        """Compute mean per-image precision, recall and an F-score for proposals.

        Args:
            path_images: directory containing the images named in
                ``image_info['image_name']``.
            path_labels: directory containing one ``<stem>.txt`` label file per
                image (``<stem>`` is the image name up to its first '.'). Each
                non-blank line is tab-separated, and its first four fields are
                box coordinates in original-image pixels.
            proposals: 2-D array of rows ``[image_index, score, x1, y1, x2, y2]``
                (the layout returned by ``generate_proposals``), expressed in the
                ``image_info['width']`` x ``image_info['height']`` space.
            image_info: dict with ``'image_name'``, ``'width'`` and ``'height'``.
            iou_thresh: a proposal counts as matched when its best overlap with
                any ground-truth box is strictly greater than this value.

        Returns:
            ``(precision, recall, fscore)``. Precision and recall are averaged
            over images. An image with no proposals or no ground-truth boxes
            contributes 0 to both sums. All three values are 0 when there are no
            images.

        Raises:
            FileNotFoundError: if an image cannot be read by OpenCV.
        """
        images_name = image_info['image_name']
        width = image_info['width']
        height = image_info['height']
        if len(images_name) == 0:
            return 0., 0., 0.

        precision = 0.
        recall = 0.
        for index, name in enumerate(images_name):
            image_path = os.path.join(path_images, name)
            img = cv2.imread(image_path)
            if img is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise FileNotFoundError('cannot read image: ' + image_path)
            # Ground truth is in original-image pixels; rescale it into the
            # width x height space that the proposals live in.
            x_scale = width / img.shape[1]
            y_scale = height / img.shape[0]
            label_path = os.path.join(path_labels, name.split('.')[0] + '.txt')
            gt_boxes = ProposalLayer._load_gt_boxes(label_path, x_scale, y_scale)

            # Only this image's proposals (column 0 holds the image index).
            proposal = proposals[proposals[:, 0] == index]
            if proposal.shape[0] == 0 or gt_boxes.shape[0] == 0:
                # Precision/recall are undefined for this image; it adds 0.
                continue
            if Config.USE_C:
                overlaps = Utils.c_cal_overlaps(proposal[:, 2:], gt_boxes)
            else:
                overlaps = Utils.cal_overlaps(proposal[:, 2:], gt_boxes)
            image_precision, image_recall = ProposalLayer._precision_recall_from_overlaps(
                overlaps, iou_thresh)
            precision += image_precision
            recall += image_recall

        precision = precision / len(images_name)
        recall = recall / len(images_name)
        # The small epsilon keeps the F-score finite when precision + recall == 0.
        fscore = 2 * precision * recall / (precision + recall + 0.0001)
        return precision, recall, fscore

    @staticmethod
    def _load_gt_boxes(label_path, x_scale, y_scale):
        """Read a tab-separated label file into scaled (N, 4) float32 boxes.

        Blank lines are skipped. Only the first four fields of each line are
        used: fields 0 and 2 are multiplied by ``x_scale``, fields 1 and 3 by
        ``y_scale``.
        """
        with open(label_path, 'r', encoding='utf-8') as f:
            rows = [[float(v) for v in line.split('\t')[:4]] for line in f if line.strip()]
        # Scale in float64 and cast once, so each value is rounded to float32 only once.
        boxes = np.array(rows, dtype=np.float64).reshape(-1, 4)
        boxes[:, 0::2] *= x_scale
        boxes[:, 1::2] *= y_scale
        return boxes.astype(np.float32)

    @staticmethod
    def _precision_recall_from_overlaps(overlaps, iou_thresh=0.7):
        """Return one image's ``(precision, recall)`` from a proposal x gt overlap matrix.

        A proposal is matched when its maximum overlap over all ground-truth
        boxes is strictly greater than ``iou_thresh``. Both values are 0 when
        either dimension is empty.
        """
        overlaps = np.asarray(overlaps)
        num_proposals, num_gt = overlaps.shape
        if num_proposals == 0 or num_gt == 0:
            return 0.0, 0.0
        matched = int(np.count_nonzero(overlaps.max(axis=1) > iou_thresh))
        # NOTE(review): recall counts matched *proposals*, not matched gt boxes,
        # so it can exceed 1 when several proposals hit the same gt box.
        return matched / num_proposals, matched / num_gt

    @staticmethod
    def c_generate_proposals(rpn_cls_prob, rpn_bbox_pred, image_info, feat_stride):
        """Generate proposals using anchors from ``Utils.c_generate_all_anchors``.

        Takes the same arguments and returns the same layout as
        ``generate_proposals``. The only difference is that the number of
        anchors per feature-map cell is fixed at 10.
        """
        # NOTE(review): assumes c_generate_all_anchors yields
        # featuremap_h * featuremap_w * 10 anchors in the same cell-major order
        # as the score map — confirm against the C implementation.
        num_anchors = 10
        anchors = Utils.c_generate_all_anchors(image_info['featuremap_w'],
                                               image_info['featuremap_h'],
                                               feat_stride)
        return ProposalLayer._proposals_from_anchors(
            anchors, num_anchors, rpn_cls_prob, rpn_bbox_pred, image_info)

    @staticmethod
    def generate_proposals(rpn_cls_prob, rpn_bbox_pred, image_info, feat_stride):
        """Generate proposals using base anchors shifted over the feature map.

        Args:
            rpn_cls_prob: classification output, reshapeable to
                (batch, featuremap_h, featuremap_w, A, 2), where A is the number
                of base anchors from ``Utils.generate_anchors()``. Channel 1 is
                used as the proposal score.
            rpn_bbox_pred: regression output, reshapeable to (-1, 4) in the same
                anchor order.
            image_info: dict with ``'featuremap_h'``, ``'featuremap_w'``,
                ``'height'``, ``'width'`` and ``'image_name'`` (one entry per
                image in the batch).
            feat_stride: pixel offset between adjacent feature-map cells.

        Returns:
            (R, 6) float array of rows ``[image_index, score, x1, y1, x2, y2]``;
            shape (0, 6) when no proposal survives.
        """
        base_anchors = Utils.generate_anchors()
        anchors = ProposalLayer._shift_anchors(base_anchors,
                                               image_info['featuremap_h'],
                                               image_info['featuremap_w'],
                                               feat_stride)
        return ProposalLayer._proposals_from_anchors(
            anchors, base_anchors.shape[0], rpn_cls_prob, rpn_bbox_pred, image_info)

    @staticmethod
    def _shift_anchors(base_anchors, featuremap_h, featuremap_w, feat_stride):
        """Place ``base_anchors`` at every feature-map cell.

        Returns a (featuremap_h * featuremap_w * A, 4) array ordered row by row
        over the feature map, then anchor by anchor within each cell.
        """
        shift_x, shift_y = np.meshgrid(np.arange(0, featuremap_w) * feat_stride,
                                       np.arange(0, featuremap_h) * feat_stride)
        shifts = np.vstack((shift_x.ravel(), shift_y.ravel(),
                            shift_x.ravel(), shift_y.ravel())).transpose()
        num_anchors = base_anchors.shape[0]
        # (1, A, 4) + (K, 1, 4) -> (K, A, 4): every base anchor at every cell.
        shifted = (base_anchors.reshape((1, num_anchors, 4))
                   + shifts.reshape((1, -1, 4)).transpose((1, 0, 2)))
        return shifted.reshape((-1, 4))

    @staticmethod
    def _tile_anchors(anchors, batchsize):
        """Repeat one image's anchors for every image: returns (batchsize * N, 4) float64."""
        anchors = np.asarray(anchors, dtype=np.float64).reshape(-1, 4)
        return np.tile(anchors, (batchsize, 1))

    @staticmethod
    def _proposals_from_anchors(anchors, num_anchors, rpn_cls_prob, rpn_bbox_pred, image_info):
        """Score, decode and post-process proposals given one image's anchors."""
        featuremap_h = image_info['featuremap_h']
        featuremap_w = image_info['featuremap_w']
        batchsize = len(image_info['image_name'])

        # Channel 1 of the 2-way per-anchor classification is used as the score.
        # NOTE(review): the original comment called this the background
        # probability; since proposals are ranked by it in descending order it
        # is presumably the foreground/text probability — confirm.
        scores = np.reshape(rpn_cls_prob,
                            [batchsize, featuremap_h, featuremap_w, num_anchors, 2])[..., 1]
        scores = scores.reshape(batchsize, -1)

        # The network predicts offsets relative to anchors; decode them into
        # image coordinates.
        total_anchors = ProposalLayer._tile_anchors(anchors, batchsize)
        bbox_deltas = np.reshape(rpn_bbox_pred, (-1, 4))
        if Config.USE_C:
            proposals = Utils.c_bbox_transform_inv(total_anchors, bbox_deltas)
        else:
            proposals = Utils.bbox_transform_inv(total_anchors, bbox_deltas)
        proposals = proposals.reshape(batchsize, -1, 4)
        return ProposalLayer._select_proposals(proposals, scores,
                                               image_info['height'], image_info['width'])

    @staticmethod
    def _select_proposals(proposals, scores, height, width, min_size=4,
                          pre_nms_topn=5000, post_nms_topn=2000, nms_thresh=0.3):
        """Clip, size-filter, rank and NMS each image's boxes.

        Args:
            proposals: (batchsize, M, 4) decoded boxes; clipped in place.
            scores: (batchsize, M) scores aligned with ``proposals``.
            height, width: image size used for clipping.

        Returns:
            2-D array of rows ``[image_index, score, x1, y1, x2, y2]``;
            shape (0, 6) when nothing survives.
        """
        rows = []
        for index in range(proposals.shape[0]):
            proposal = ProposalLayer._clip_boxes(proposals[index], [height, width])
            score = scores[index]
            keep = ProposalLayer._filter_boxes(proposal, min_size)
            proposal = proposal[keep, :]
            score = score[keep]
            if proposal.shape[0] == 0:
                continue
            # Highest score first, truncated before NMS.
            order = score.ravel().argsort()[::-1][:pre_nms_topn]
            proposal = proposal[order, :]
            score = score[order]
            dets = np.hstack((proposal, score.reshape(-1, 1)))
            if Config.USE_C:
                keep = Utils.c_nms(dets, nms_thresh)
            else:
                keep = Utils.nms(dets, nms_thresh)
            keep = keep[:post_nms_topn]
            proposal = proposal[keep, :]
            score = score[keep]
            image_col = np.full((score.shape[0], 1), index, dtype=np.float64)
            rows.append(np.hstack((image_col, score.reshape(-1, 1), proposal)))
        if not rows:
            return np.zeros((0, 6))
        return np.vstack(rows)

    @staticmethod
    def _filter_boxes(boxes, min_size):
        """Return indices of boxes whose width and height (inclusive +1) are both >= min_size."""
        ws = boxes[:, 2] - boxes[:, 0] + 1
        hs = boxes[:, 3] - boxes[:, 1] + 1
        keep = np.where((ws >= min_size) & (hs >= min_size))[0]
        return keep

    @staticmethod
    def _clip_boxes(boxes, im_shape):
        """Clip (x1, y1, x2, y2) boxes in place to a ``[height, width]`` image and return them."""
        # x1 in [0, width - 1]
        boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0)
        # y1 in [0, height - 1]
        boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0)
        # x2 in [0, width - 1]
        boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0)
        # y2 in [0, height - 1]
        boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0)
        return boxes

if __name__ == "__main__":
    # Nothing to run directly; the proposal layer is used via its class methods.
    pass
